module Spec.AES.Test

open Lib.IntTypes
open Lib.RawIntTypes
open Lib.Sequence
open Spec.AES
open Lib.LoopCombinators

#set-options "--lax"

(* AES-128 CTR known-answer vector 0.
   [test] feeds these to [aes128_ctr_encrypt_bytes] with a nonce length of 12.
   NOTE(review): the values look like NIST SP 800-38A, F.5.1 (CTR-AES128.Encrypt,
   block #1), with the 16-byte initial counter block split into a 12-byte nonce
   and a 32-bit counter -- confirm provenance. *)

(* 16-byte AES-128 key. *)
let test_key = List.Tot.map u8 [
  0x2b; 0x7e; 0x15; 0x16; 0x28; 0xae; 0xd2; 0xa6;
  0xab; 0xf7; 0x15; 0x88; 0x09; 0xcf; 0x4f; 0x3c
]

(* 12-byte nonce. *)
let test_nonce = List.Tot.map u8 [
  0xf0; 0xf1; 0xf2 ; 0xf3; 0xf4; 0xf5; 0xf6; 0xf7; 0xf8; 0xf9; 0xfa; 0xfb
]

(* Initial counter value passed alongside the nonce. *)
let test_counter = 0xfcfdfeff

(* One 16-byte plaintext block. *)
let test_plaintext = List.Tot.map u8 [
  0x6b; 0xc1; 0xbe; 0xe2; 0x2e; 0x40; 0x9f; 0x96; 0xe9; 0x3d; 0x7e; 0x11; 0x73; 0x93; 0x17; 0x2a
]

(* Expected ciphertext for [test_plaintext]. *)
let test_ciphertext = List.Tot.map u8 [
  0x87; 0x4d; 0x61; 0x91; 0xb6; 0x20; 0xe3; 0x26; 0x1b; 0xef; 0x68; 0x64; 0x99; 0x0d; 0xb6; 0xce
]
(* AES-128 CTR known-answer vector 1, from RFC 3686.
   [test] feeds these to [aes128_ctr_encrypt_bytes] with a nonce length of 12. *)

(* 16-byte AES-128 key. *)
let test_key1 = List.Tot.map u8 [
0xAE; 0x68; 0x52; 0xF8; 0x12; 0x10; 0x67; 0xCC; 0x4B; 0xF7; 0xA5; 0x76; 0x55; 0x77; 0xF3; 0x9E
]

(* 12-byte counter-block prefix: 4 bytes followed by 8 zero bytes.
   NOTE(review): presumably the RFC 3686 4-byte nonce followed by the 8-byte IV
   -- confirm against the RFC.
   The remaining 4 bytes of the counter block are NOT listed here: they are the
   block counter, supplied separately through [test_counter1]. Keeping the list
   at exactly 12 bytes makes it agree with the nonce length passed by [test]
   and with [test_nonce] / [test_nonce2]. *)
let test_nonce1 = List.Tot.map u8 [
  0x00; 0x00; 0x00; 0x30; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00
]

(* Initial block counter. *)
let test_counter1 = 1

(* One 16-byte plaintext block. *)
let test_plaintext1 = List.Tot.map u8 [
 0x53; 0x69; 0x6E; 0x67; 0x6C; 0x65; 0x20; 0x62; 0x6C; 0x6F; 0x63; 0x6B; 0x20; 0x6D; 0x73; 0x67
]

(* Expected ciphertext for [test_plaintext1]. *)
let test_ciphertext1 = List.Tot.map u8 [
 0xE4; 0x09; 0x5D; 0x4F; 0xB7; 0xA7; 0xB3; 0x79; 0x2D; 0x61; 0x75; 0xA3; 0x26; 0x13; 0x11; 0xB8
]

(* AES-128 CTR known-answer vector 2 (two plaintext blocks).
   [test] feeds these to [aes128_ctr_encrypt_bytes] with a nonce length of 12.
   NOTE(review): presumably RFC 3686 test vector #2 -- confirm provenance. *)

(* 16-byte AES-128 key. *)
let test_key2 = List.Tot.map u8 [
 0x7E; 0x24; 0x06; 0x78; 0x17; 0xFA; 0xE0; 0xD7; 0x43; 0xD6; 0xCE; 0x1F; 0x32; 0x53; 0x91; 0x63
]

(* 12-byte nonce. *)
let test_nonce2 = List.Tot.map u8 [
 0x00; 0x6C; 0xB6; 0xDB; 0xC0; 0x54; 0x3B; 0x59; 0xDA; 0x48; 0xD9; 0x0B
]

(* Initial block counter. *)
let test_counter2 = 1

(* 32-byte plaintext: the bytes 0x00 .. 0x1F in order. *)
let test_plaintext2 = List.Tot.map u8 [
 0x00; 0x01; 0x02; 0x03; 0x04; 0x05; 0x06; 0x07; 0x08; 0x09; 0x0A; 0x0B; 0x0C; 0x0D; 0x0E; 0x0F; 0x10; 0x11; 0x12; 0x13; 0x14; 0x15; 0x16; 0x17; 0x18; 0x19; 0x1A; 0x1B; 0x1C; 0x1D; 0x1E; 0x1F
]

(* Expected 32-byte ciphertext for [test_plaintext2]. *)
let test_ciphertext2 = List.Tot.map u8 [
 0x51; 0x04; 0xA1; 0x06; 0x16; 0x8A; 0x72; 0xD9; 0x79; 0x0D; 0x41; 0xEE; 0x8E; 0xDA; 0xD3; 0x88; 0xEB; 0x2E; 0x1E; 0xFC; 0x46; 0xDA; 0x57; 0xC8; 0xFC; 0xE6; 0x30; 0xDF; 0x91; 0x41; 0xBE; 0x28
]

(* Single-block AES-128 known-answer vector 1: [test] expands the key with
   [aes_key_expansion AES128] and encrypts the block with [aes_encrypt_block].
   NOTE(review): looks like a NIST AESAVS "variable key" vector -- confirm. *)

(* 16-byte key with only the most significant bit set. *)
let test1_key_block = List.Tot.map u8 [
 0x80; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00
]

(* All-zero 16-byte plaintext block. *)
let test1_plaintext_block = List.Tot.map u8 [
 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00
]

(* Expected ciphertext block. *)
let test1_ciphertext_block = List.Tot.map u8 [
 0x0e; 0xdd; 0x33; 0xd3; 0xc6; 0x21; 0xe5; 0x46; 0x45; 0x5b; 0xd8; 0xba; 0x14; 0x18; 0xbe; 0xc8
]

(* Single-block AES-128 known-answer vector 2, used by [test] in the same way
   as the [test1_*_block] values.
   NOTE(review): looks like a NIST AESAVS "variable key" vector -- confirm. *)

(* 16-byte key: five 0xff bytes, then 0xf0, then zeros. *)
let test2_key_block = List.Tot.map u8 [
 0xff; 0xff; 0xff; 0xff; 0xff; 0xf0; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00
]

(* All-zero 16-byte plaintext block; [test] also reuses it together with
   [test3_key_block]. *)
let test2_plaintext_block = List.Tot.map u8 [
 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00
]

(* Expected ciphertext block. *)
let test2_ciphertext_block = List.Tot.map u8 [
 0xe6; 0xc4; 0x80; 0x7a; 0xe1; 0x1f; 0x36; 0xf0; 0x91; 0xc5; 0x7d; 0x9f; 0xb6; 0x85; 0x48; 0xd1
]

(* Single-block AES-128 known-answer vector 3. There is no dedicated plaintext
   for this vector: [test] encrypts the all-zero [test2_plaintext_block] under
   this key. *)

(* 16-byte AES-128 key. *)
let test3_key_block = List.Tot.map u8 [
  0xfe; 0xff; 0xe9; 0x92; 0x86; 0x65; 0x73; 0x1c; 0x6d; 0x6a; 0x8f; 0x94; 0x67; 0x30; 0x83; 0x08
]

(* Expected ciphertext of the all-zero block under [test3_key_block]. *)
let test3_ciphertext_block = List.Tot.map u8 [
 0xb8; 0x3b; 0x53; 0x37; 0x08; 0xbf; 0x53; 0x5d; 0x0a; 0xa6; 0xe5; 0x29; 0x80; 0xd5; 0x3b; 0x78
]

(* 32-byte key that [test] passes to [aes_key_expansion AES256]; the expected
   result is [test1_output_expanded].
   NOTE(review): looks like the FIPS-197 Appendix A.3 example key -- confirm. *)
let test1_input_key1 = of_list (List.Tot.map u8 [
  0x60; 0x3d; 0xeb; 0x10; 0x15; 0xca; 0x71; 0xbe;
  0x2b; 0x73; 0xae; 0xf0; 0x85; 0x7d; 0x77; 0x81;
  0x1f; 0x35; 0x2c; 0x07; 0x3b; 0x61; 0x08; 0xd7;
  0x2d; 0x98; 0x10; 0xa3; 0x09; 0x14; 0xdf; 0xf4
])

(* Expected AES-256 expanded key for [test1_input_key1]: 240 bytes (30 rows of
   8), i.e. 15 chunks of 16 bytes. The first 32 bytes are the input key itself.
   [test] compares this against the output of [aes_key_expansion AES256]. *)
let test1_output_expanded = of_list(List.Tot.map u8 [
  0x60; 0x3d; 0xeb; 0x10; 0x15; 0xca; 0x71; 0xbe;
  0x2b; 0x73; 0xae; 0xf0; 0x85; 0x7d; 0x77; 0x81;
  0x1f; 0x35; 0x2c; 0x07; 0x3b; 0x61; 0x08; 0xd7;
  0x2d; 0x98; 0x10; 0xa3; 0x09; 0x14; 0xdf; 0xf4;
  0x9b; 0xa3; 0x54; 0x11; 0x8e; 0x69; 0x25; 0xaf;
  0xa5; 0x1a; 0x8b; 0x5f; 0x20; 0x67; 0xfc; 0xde;
  0xa8; 0xb0; 0x9c; 0x1a; 0x93; 0xd1; 0x94; 0xcd;
  0xbe; 0x49; 0x84; 0x6e; 0xb7; 0x5d; 0x5b; 0x9a;
  0xd5; 0x9a; 0xec; 0xb8; 0x5b; 0xf3; 0xc9; 0x17;
  0xfe; 0xe9; 0x42; 0x48; 0xde; 0x8e; 0xbe; 0x96;
  0xb5; 0xa9; 0x32; 0x8a; 0x26; 0x78; 0xa6; 0x47;
  0x98; 0x31; 0x22; 0x29; 0x2f; 0x6c; 0x79; 0xb3;
  0x81; 0x2c; 0x81; 0xad; 0xda; 0xdf; 0x48; 0xba;
  0x24; 0x36; 0x0a; 0xf2; 0xfa; 0xb8; 0xb4; 0x64;
  0x98; 0xc5; 0xbf; 0xc9; 0xbe; 0xbd; 0x19; 0x8e;
  0x26; 0x8c; 0x3b; 0xa7; 0x09; 0xe0; 0x42; 0x14;
  0x68; 0x00; 0x7b; 0xac; 0xb2; 0xdf; 0x33; 0x16;
  0x96; 0xe9; 0x39; 0xe4; 0x6c; 0x51; 0x8d; 0x80;
  0xc8; 0x14; 0xe2; 0x04; 0x76; 0xa9; 0xfb; 0x8a;
  0x50; 0x25; 0xc0; 0x2d; 0x59; 0xc5; 0x82; 0x39;
  0xde; 0x13; 0x69; 0x67; 0x6c; 0xcc; 0x5a; 0x71;
  0xfa; 0x25; 0x63; 0x95; 0x96; 0x74; 0xee; 0x15;
  0x58; 0x86; 0xca; 0x5d; 0x2e; 0x2f; 0x31; 0xd7;
  0x7e; 0x0a; 0xf1; 0xfa; 0x27; 0xcf; 0x73; 0xc3;
  0x74; 0x9c; 0x47; 0xab; 0x18; 0x50; 0x1d; 0xda;
  0xe2; 0x75; 0x7e; 0x4f; 0x74; 0x01; 0x90; 0x5a;
  0xca; 0xfa; 0xaa; 0xe3; 0xe4; 0xd5; 0x9b; 0x34;
  0x9a; 0xdf; 0x6a; 0xce; 0xbd; 0x10; 0x19; 0x0d;
  0xfe; 0x48; 0x90; 0xd1; 0xe6; 0x18; 0x8d; 0x0b;
  0x04; 0x6d; 0xf3; 0x44; 0x70; 0x6c; 0x63; 0x1e
])

(* AES-256 single-block vector: [test] encrypts [test2_input_plaintext] and
   decrypts [test2_output_ciphertext] under this key.
   NOTE(review): looks like the FIPS-197 Appendix C.3 example -- confirm. *)

(* 32-byte key: the bytes 0x00 .. 0x1f in order. *)
let test2_input_key = of_list (List.Tot.map u8 [
  0x00; 0x01; 0x02; 0x03; 0x04; 0x05; 0x06; 0x07;
  0x08; 0x09; 0x0a; 0x0b; 0x0c; 0x0d; 0x0e; 0x0f;
  0x10; 0x11; 0x12; 0x13; 0x14; 0x15; 0x16; 0x17;
  0x18; 0x19; 0x1a; 0x1b; 0x1c; 0x1d; 0x1e; 0x1f
])

(* 16-byte plaintext block. *)
let test2_input_plaintext = of_list (List.Tot.map u8 [
  0x00; 0x11; 0x22; 0x33; 0x44; 0x55; 0x66; 0x77;
  0x88; 0x99; 0xaa; 0xbb; 0xcc; 0xdd; 0xee; 0xff
])

(* Expected ciphertext block. *)
let test2_output_ciphertext = of_list (List.Tot.map u8 [
  0x8e; 0xa2; 0xb7; 0xca; 0x51; 0x67; 0x45; 0xbf;
  0xea; 0xfc; 0x49; 0x90; 0x4b; 0x49; 0x60; 0x89
])

(* AES-256 single-block vector with an all-zero plaintext: [test] encrypts
   [test3_input_plaintext] and decrypts [test3_output_ciphertext] under this key.
   NOTE(review): provenance not stated in this file -- confirm the source. *)

(* 32-byte AES-256 key. *)
let test3_input_key = of_list (List.Tot.map u8 [
  0xc4; 0x7b; 0x02; 0x94; 0xdb; 0xbb; 0xee; 0x0f;
  0xec; 0x47; 0x57; 0xf2; 0x2f; 0xfe; 0xee; 0x35;
  0x87; 0xca; 0x47; 0x30; 0xc3; 0xd3; 0x3b; 0x69;
  0x1d; 0xf3; 0x8b; 0xab; 0x07; 0x6b; 0xc5; 0x58
])

(* All-zero 16-byte plaintext block. *)
let test3_input_plaintext = of_list (List.Tot.map u8 [
  0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00;
  0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00
])

(* Expected ciphertext block. *)
let test3_output_ciphertext = of_list (List.Tot.map u8 [
  0x46; 0xf2; 0xfb; 0x34; 0x2d; 0x6f; 0x0a; 0xb4;
  0x77; 0x47; 0x6f; 0xc5; 0x01; 0x24; 0x2c; 0x5f
])

(* AES-256 single-block vector with an all-zero plaintext: [test] encrypts
   [test4_input_plaintext] and decrypts [test4_output_ciphertext] under this key.
   NOTE(review): provenance not stated in this file -- confirm the source. *)

(* 32-byte AES-256 key. *)
let test4_input_key = of_list (List.Tot.map u8 [
  0xcc; 0xd1; 0xbc; 0x3c; 0x65; 0x9c; 0xd3; 0xc5;
  0x9b; 0xc4; 0x37; 0x48; 0x4e; 0x3c; 0x5c; 0x72;
  0x44; 0x41; 0xda; 0x8d; 0x6e; 0x90; 0xce; 0x55;
  0x6c; 0xd5; 0x7d; 0x07; 0x52; 0x66; 0x3b; 0xbc
])

(* All-zero 16-byte plaintext block. *)
let test4_input_plaintext = of_list (List.Tot.map u8 [
  0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00;
  0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00; 0x00
])

(* Expected ciphertext block. *)
let test4_output_ciphertext = of_list (List.Tot.map u8 [
  0x30; 0x4f; 0x81; 0xab; 0x61; 0xa8; 0x0c; 0x2e;
  0x74; 0x3b; 0x94; 0xd5; 0x00; 0x2a; 0x12; 0x6b
])



(* Prints [msg], then the [expected] and [computed] byte sequences (length in
   decimal, contents as zero-padded hex), and finally a "Success !" or
   "Failed !" verdict.

   Returns [true] iff both sequences have the same length and agree byte for
   byte. The length is checked first because [for_all2] is instantiated at
   [length computed] and must not be applied to a sequence of another length;
   a wrong-sized output is therefore reported as a plain failure. *)
let test_compare_buffers (msg:string) (expected:seq uint8) (computed:seq uint8) =
  IO.print_string "\n";
  IO.print_string msg;
  IO.print_string "\nexpected (";
  IO.print_uint32_dec (UInt32.uint_to_t (length expected));
  IO.print_string "):\n";
  FStar.List.iter (fun a -> IO.print_uint8_hex_pad (u8_to_UInt8 a)) (to_list expected);
  IO.print_string "\n";
  IO.print_string "computed (";
  IO.print_uint32_dec (UInt32.uint_to_t (length computed));
  IO.print_string "):\n";
  FStar.List.iter (fun a -> IO.print_uint8_hex_pad (u8_to_UInt8 a)) (to_list computed);
  IO.print_string "\n";
  let result =
    (* [&&] guards the element-wise comparison behind the length check *)
    length expected = length computed &&
    for_all2 #uint8 #uint8 #(length computed) (fun x y -> uint_to_nat #U8 x = uint_to_nat #U8 y)
      computed expected
  in
  IO.print_string (if result then "\nSuccess !\n" else "\nFailed !\n");
  result

(* Runs every known-answer test of this module and returns [true] only when
   all of them pass.

   Steps:
     - print the image of the bytes 0..255 under [sub_byte] (debug output only,
       not checked);
     - three AES128-CTR encryptions with a 12-byte nonce (TEST0-TEST2);
     - three single-block AES-128 encryptions (TEST3-TEST5);
     - one AES-256 key expansion and three AES-256 encrypt / decrypt pairs.

   Every vector is checked with [test_compare_buffers], which prints both
   buffers and a verdict. Each verdict is bound to a name before the final
   conjunction, so a failing vector never stops the later ones from running. *)
let test() : FStar.All.ML bool =
  (* S-box dump over the identity table 0..255 *)
  let zeros = create 256 (u8 0) in
  let seqi = repeati #(lseq uint8 256) 256 (fun i s -> s.[i] <- u8 i) zeros in
  let seqsbox = map (fun s -> sub_byte s) seqi in
  IO.print_string "sbox i:     \n";
  FStar.List.iter (fun a -> IO.print_string (UInt8.to_string (u8_to_UInt8 a)); IO.print_string " ; ") (to_list #uint8 seqsbox);
  IO.print_string "\n";

  (* AES128-CTR with a 12-byte nonce *)
  IO.print_string "AES-128 CTR TESTS:\n";
  let cip0 =
    aes128_ctr_encrypt_bytes (of_list test_key) 12 (of_list test_nonce) test_counter
      (of_list test_plaintext) in
  let ctr0 = test_compare_buffers "AES128-CTR TEST0" (of_list test_ciphertext) cip0 in
  let cip1 =
    aes128_ctr_encrypt_bytes (of_list test_key1) 12 (of_list test_nonce1) test_counter1
      (of_list test_plaintext1) in
  let ctr1 = test_compare_buffers "AES128-CTR TEST1" (of_list test_ciphertext1) cip1 in
  let cip2 =
    aes128_ctr_encrypt_bytes (of_list test_key2) 12 (of_list test_nonce2) test_counter2
      (of_list test_plaintext2) in
  let ctr2 = test_compare_buffers "AES128-CTR TEST2" (of_list test_ciphertext2) cip2 in

  (* Plain AES-128 block function *)
  IO.print_string "AES-128 BLOCK TESTS:\n";
  let xkey3 = aes_key_expansion AES128 (of_list test1_key_block) in
  let cip3 = aes_encrypt_block AES128 xkey3 (of_list test1_plaintext_block) in
  let blk3 = test_compare_buffers "AES128 block TEST3" (of_list test1_ciphertext_block) cip3 in
  let xkey4 = aes_key_expansion AES128 (of_list test2_key_block) in
  let cip4 = aes_encrypt_block AES128 xkey4 (of_list test2_plaintext_block) in
  let blk4 = test_compare_buffers "AES128 block TEST4" (of_list test2_ciphertext_block) cip4 in
  (* vector 5 has no plaintext of its own: it reuses the all-zero block of vector 4 *)
  let xkey5 = aes_key_expansion AES128 (of_list test3_key_block) in
  let cip5 = aes_encrypt_block AES128 xkey5 (of_list test2_plaintext_block) in
  let blk5 = test_compare_buffers "AES128 block TEST5" (of_list test3_ciphertext_block) cip5 in

  (* AES-256: key expansion, then encrypt / decrypt pairs *)
  IO.print_string "AES-256 TESTS 6:\n";
  let computed1 = aes_key_expansion AES256 test1_input_key1 in
  let result1 = test_compare_buffers "TEST1: key expansion" test1_output_expanded computed1 in
  let test2_xkey = aes_key_expansion AES256 test2_input_key in
  let test2_computed = aes_encrypt_block AES256 test2_xkey test2_input_plaintext in
  let result2 = test_compare_buffers "TEST2: cipher" test2_output_ciphertext test2_computed in
  let test2_dkey = aes_dec_key_expansion AES256 test2_input_key in
  let test2_inv_computed = aes_decrypt_block AES256 test2_dkey test2_output_ciphertext in
  let result3 = test_compare_buffers "TEST3: inv_cipher" test2_input_plaintext test2_inv_computed in
  let test3_xkey = aes_key_expansion AES256 test3_input_key in
  let test3_computed = aes_encrypt_block AES256 test3_xkey test3_input_plaintext in
  let result4 = test_compare_buffers "TEST4: cipher" test3_output_ciphertext test3_computed in
  let test3_dkey = aes_dec_key_expansion AES256 test3_input_key in
  let test3_inv_computed = aes_decrypt_block AES256 test3_dkey test3_output_ciphertext in
  let result5 = test_compare_buffers "TEST5: inv_cipher" test3_input_plaintext test3_inv_computed in
  let test4_xkey = aes_key_expansion AES256 test4_input_key in
  let test4_computed = aes_encrypt_block AES256 test4_xkey test4_input_plaintext in
  let result6 = test_compare_buffers "TEST6: cipher" test4_output_ciphertext test4_computed in
  let test4_dkey = aes_dec_key_expansion AES256 test4_input_key in
  let test4_inv_computed = aes_decrypt_block AES256 test4_dkey test4_output_ciphertext in
  let result7 = test_compare_buffers "TEST7: inv_cipher" test4_input_plaintext test4_inv_computed in

  ctr0 && ctr1 && ctr2 && blk3 && blk4 && blk5 &&
  result1 && result2 && result3 && result4 && result5 && result6 && result7
